/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include <memory>
#include <optional>
#include <shared_mutex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "velox/experimental/codegen/CodegenExceptions.h"
#include "velox/experimental/codegen/ast/CodegenCtx.h"
#include "velox/experimental/codegen/compiler_utils/LibraryDescriptor.h"
#include "velox/experimental/codegen/udf_manager/ExpressionNullMode.h"
#include "velox/experimental/codegen/udf_manager/UDFManager.h"
#include "velox/type/Type.h"

namespace facebook {
namespace velox {
namespace codegen {

using ASTNodePtr = std::shared_ptr<class ASTNode>;

/// Partial code produced by ASTNode::generateCode: a piece of generated C++
/// source together with the name of the variable that holds the expression's
/// value once that source has executed.
struct CodeSnippet {
  CodeSnippet(
      const std::string& outputVarName = "",
      const std::string& code = "")
      : outputVarName_(outputVarName), code_(code) {}

  /// The generated statements, as-is.
  const std::string& code() const {
    return code_;
  }

  /// Wraps the snippet in a named, capture-less lambda definition:
  ///   auto <lambdaName> = [](){ <code> return <outputVar>; }
  /// No trailing semicolon is appended after the closing brace.
  std::string getAsLambda(const std::string& lambdaName) const {
    std::string lambda = "auto ";
    lambda += lambdaName;
    lambda += " = [](){ ";
    lambda += code_;
    lambda += " return ";
    lambda += outputVarName_;
    lambda += "; }";
    return lambda;
  }

 private:
  // Variable that holds the expression result after code_ has run.
  std::string outputVarName_;

  // Statements that must execute to produce the result.
  std::string code_;
};

/// Abstract base class for every expression node supported by codegen.
///
/// A node carries its SQL type and a nullability flag. Subclasses describe
/// their children, how nullability flows through them, and how to emit C++
/// source for the expression.
class ASTNode {
 public:
  explicit ASTNode(const TypePtr& type)
      : type_(std::const_pointer_cast<Type>(type)) {}

  // Pure virtual to keep ASTNode abstract; a pure virtual destructor still
  // needs a definition, which is not provided in this header.
  virtual ~ASTNode() = 0;

  // Validate that the expression node is complete and ready for codegen.
  // Problems are reported by throwing (e.g. ASTValidationException).
  virtual void validate() const = 0;

  // Returns the children expressions of this node (as a fresh vector).
  virtual const std::vector<ASTNodePtr> children() const = 0;

  // Computes maybeNull() for this node and its subtree. The default recurses
  // into the children and then applies getNullMode(); nodes whose null mode
  // is ExpressionNullMode::Custom must override this.
  virtual void propagateNullability() {
    defaultNullabilityPropagation();
  }

  // List of include files required for the expression to be executed.
  virtual std::vector<std::string> getHeaderFiles() const {
    return {};
  }

  // List of libs required for the expression to be executed.
  virtual std::vector<compiler_utils::LibraryDescriptor> getLibs() const {
    return {};
  }

  // Checked downcast; returns nullptr if this node is not a T.
  template <typename T>
  const T* as() const {
    return dynamic_cast<const T*>(this);
  }

  // Checked downcast; returns nullptr if this node is not a T.
  template <typename T>
  T* as() {
    return dynamic_cast<T*>(this);
  }

  virtual bool isConstantExpression() const {
    return false;
  }

  // Generate code, taking nullability into account, that writes the result
  // of the expression into the variable named `outputVarName`. For complex
  // types the output should not be null before it is written. Returns the
  // generated code together with the output variable name.
  virtual CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const = 0;

  // Return the null mode of the expression, used by the default
  // propagateNullability.
  virtual ExpressionNullMode getNullMode() const = 0;

  // Return the sql data type of the expression. Requires typePtr() != nullptr.
  const velox::Type& type() const {
    return *type_;
  }

  /// Return the sql data type of the expression (may be null if untyped).
  velox::TypePtr typePtr() const {
    return type_;
  }

  /// Check whether the expression is typed.
  bool typed() const {
    return type_ != nullptr && type_->kind() != TypeKind::INVALID;
  }

  // Throws ASTValidationException unless typed().
  void validateTyped() const {
    if (!typed()) {
      throw ASTValidationException("ast node not typed");
    }
  }

  // Returns the nullability of the node (true until set otherwise).
  bool maybeNull() const {
    return maybeNull_;
  }

  // Defined out of line (not visible here); presumably marks the input
  // references in this subtree as not nullable — TODO confirm.
  void markAllInputsNotNullable();

 protected:
  // Set the nullability of the node.
  void setMaybeNull(bool maybeNull) {
    maybeNull_ = maybeNull;
  }

 private:
  // The default nullability propagation for function-call-like nodes,
  // driven by getNullMode().
  void defaultNullabilityPropagation() {
    // children() builds a fresh vector; fetch it once for both passes.
    const auto kids = children();
    for (const auto& child : kids) {
      child->propagateNullability();
    }

    switch (getNullMode()) {
      case ExpressionNullMode::NullInNullOut:
      case ExpressionNullMode::NullableInNullableOut:
        for (const auto& child : kids) {
          if (child->maybeNull()) {
            setMaybeNull(true);
            return;
          }
        }
        // No child may be null, so neither may this node.
        [[fallthrough]];
      case ExpressionNullMode::NotNull:
        setMaybeNull(false);
        return;
      case ExpressionNullMode::Custom:
        throw CodegenNotSupported(
            "Error in AST Design, propagateNullability must be overridden");
    }
  }

  // Stores the type of the expression.
  std::shared_ptr<velox::Type> type_ = nullptr;

  // Stores the nullability of the node.
  bool maybeNull_ = true;
};

/// Node representing a reference to an input value, i.e. a column of the
/// input row identified by its name and position.
class InputRefExpr final : public ASTNode {
 public:
  /// @param type  type of the referenced column
  /// @param name  name of the referenced column
  /// @param index position of the referenced column in the input row
  InputRefExpr(const TypePtr& type, const std::string& name, size_t index)
      : ASTNode(type), name_(name), index_(index) {}

  CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const override;

  // Leaf node: no children.
  const std::vector<ASTNodePtr> children() const override {
    return {};
  }

  ExpressionNullMode getNullMode() const override {
    return ExpressionNullMode::Custom;
  }

  // Nothing to derive: nullability is whatever was last set through
  // setMaybeNull(), or the ASTNode default if it was never set.
  void propagateNullability() override {}

  /// Set the nullability of the input (public for callers building the AST).
  void setMaybeNull(bool maybeNull) {
    ASTNode::setMaybeNull(maybeNull);
  }

  // index_ is unsigned, so there is no negative-index case to reject; only
  // the type needs checking.
  void validate() const override {
    validateTyped();
  }

  /// Return the index of referenced column
  size_t index() const {
    return index_;
  }

  /// Return the column name
  const std::string& name() const {
    return name_;
  }

 private:
  /// Name of the referenced column
  std::string name_;

  /// The index of referenced column
  size_t index_;
};

// An expression that combines the children expressions into a row (tuple)
class MakeRowExpression final : public ASTNode {
 public:
  MakeRowExpression(
      const TypePtr& type,
      const std::vector<ASTNodePtr>& children)
      : ASTNode(type), children_(children) {}

  CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const override;

  const std::vector<ASTNodePtr> children() const override {
    return children_;
  }

  ExpressionNullMode getNullMode() const override {
    return ExpressionNullMode::NotNull;
  }

  void validate() const override {
    validateTyped();
    if (children_.size() == 0) {
      throw ASTValidationException(
          "output expression should have at least one child");
    }
    for (auto& child : children_) {
      child->validate();
    }
  }

  /// Return the size of the output tuple
  size_t width() {
    return children_.size();
  }

 private:
  /// The output expression, a child at index X write to output[X]
  std::vector<ASTNodePtr> children_;
};

// If expression AST node
class IfExpression final : public ASTNode {
 public:
  IfExpression(
      const TypePtr& type,
      ASTNodePtr condition,
      ASTNodePtr thenPart,
      ASTNodePtr elsePart,
      bool isEager = false)
      : ASTNode(type),
        condition_(condition),
        thenPart_(thenPart),
        elsePart_(elsePart),
        isEager_(isEager) {}

  IfExpression(
      const TypePtr& type,
      ASTNodePtr condition,
      ASTNodePtr thenPart,
      bool isEager = false)
      : ASTNode(type),
        condition_(condition),
        thenPart_(thenPart),
        elsePart_(nullptr),
        isEager_(isEager) {}

  void propagateNullability() override;

  CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const override;

  void validate() const override;

  const std::vector<ASTNodePtr> children() const override {
    return {condition_, thenPart_, elsePart_};
  }

  ExpressionNullMode getNullMode() const override {
    return ExpressionNullMode::Custom;
  }

 private:
  // AST node representing the condition
  ASTNodePtr condition_;

  // AST node representing then part
  ASTNodePtr thenPart_;

  // AST node representing else part
  ASTNodePtr elsePart_;

  // If isEager_ then and else always executed for side effects but only one
  // value returned
  bool isEager_;
};

// Switch expression AST node.
class SwitchExpression final : public ASTNode {
 public:
  // Takes the inputs by const reference (like the other node constructors)
  // so temporaries and const vectors are accepted; the vector is copied.
  SwitchExpression(const TypePtr& type, const std::vector<ASTNodePtr>& inputs)
      : ASTNode(type), inputs_(inputs) {}

  void propagateNullability() override;

  CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const override;

  void validate() const override;

  const std::vector<ASTNodePtr> children() const override {
    return inputs_;
  }

  ExpressionNullMode getNullMode() const override {
    return ExpressionNullMode::Custom;
  }

 private:
  // A sequence of WHEN ... THEN ... and maybe a trailing ELSE
  std::vector<ASTNodePtr> inputs_;
};

/// AST node representing a general function call, with no codegen-specific
/// optimizations or handling.
class UDFCallExpr final : public ASTNode {
 public:
  /// Validates `udfInformation` (without requiring velox names to be set);
  /// the validation call may throw.
  UDFCallExpr(
      const TypePtr& type,
      const UDFInformation& udfInformation,
      const std::vector<ASTNodePtr>& children)
      : ASTNode(type), udfInformation_(udfInformation), children_(children) {
    udfInformation.validate(false /*veloxNamesMustBeSet*/);
  }

  void validate() const override;

  const std::vector<ASTNodePtr> children() const override {
    return children_;
  }

  // Header files declared by the UDF, or none.
  std::vector<std::string> getHeaderFiles() const override {
    if (udfInformation_.hasHeaderFiles()) {
      return udfInformation_.getHeaderFiles();
    }
    return {};
  }

  // Libraries declared by the UDF, or none.
  std::vector<compiler_utils::LibraryDescriptor> getLibs() const override {
    if (udfInformation_.hasLibs()) {
      return udfInformation_.getLibs();
    }
    return {};
  }

  CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const override;

  // The null mode is taken from the UDF description.
  ExpressionNullMode getNullMode() const override {
    return udfInformation_.getNullMode();
  }

  // Whether the UDF declares an optional (nullable) output.
  bool isNullableOutput() const {
    return udfInformation_.isOptionalOutput();
  }

 private:
  // Return the function name used in the generated code. (Returned by plain
  // value so callers can move from it; the class is final, so the former
  // `protected` access was already equivalent to private.)
  std::string getFunctionName() const {
    return udfInformation_.getCalledFunctionName();
  }

  // Information about the called udf
  const UDFInformation udfInformation_;

  // Function call input arguments
  std::vector<ASTNodePtr> children_;
};

class CoalesceExpr final : public ASTNode {
 public:
  CoalesceExpr(const TypePtr& type, const std::vector<ASTNodePtr>& children)
      : ASTNode(type), children_(children) {}

  void validate() const override {
    validateTyped();
  }

  const std::vector<ASTNodePtr> children() const override {
    return children_;
  }

  void propagateNullability() override;

  CodeSnippet generateCode(
      CodegenCtx& exprCodegenCtx,
      const std::string& outputVarName) const override;

  virtual ExpressionNullMode getNullMode() const override {
    return ExpressionNullMode::Custom;
  }

 private:
  std::vector<ASTNodePtr> children_;
};
} // namespace codegen
} // namespace velox
} // namespace facebook
